Toxic Comment Filter

BiLSTM model for multi-label classification in a toxic comment filter
code
Deep Learning
Python, R
Author

Simone Brazzi

Published

August 2, 2024

Introduction

  • Build a model able to filter user comments according to how harmful their language is.
  • Preprocess the text by removing the tokens that do not contribute significantly at the semantic level.
  • Transform the text corpus into sequences.
  • Build a Deep Learning model with recurrent layers for a multilabel classification task.

At prediction time, the model must return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). This way, a harmless comment is classified by a vector of only 0s [0,0,0,0,0,0]. Conversely, a harmful comment has at least one 1 among the 6 labels.
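As a minimal sketch of this output format, assuming the model emits one sigmoid probability per label and a decision threshold of 0.5 (the threshold is an assumption here; the post tunes it later), the probabilities can be turned into the 0/1 vector, and a comment is flagged as clean when every entry is 0. The helper names and the probabilities are hypothetical.

```python
import numpy as np

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_label_vectors(probabilities, threshold=0.5):
    """Binarize per-label sigmoid probabilities into 0/1 vectors.
    The 0.5 default is an illustrative assumption, not a tuned value."""
    return (np.asarray(probabilities) >= threshold).astype(int)

def is_clean(label_vectors):
    """A comment is clean when none of its 6 labels is set."""
    return ~label_vectors.any(axis=1)

# two hypothetical comments: one harmless, one toxic and insulting
probs = np.array([
    [0.02, 0.01, 0.03, 0.00, 0.04, 0.01],
    [0.91, 0.10, 0.40, 0.02, 0.75, 0.05],
])
vectors = to_label_vectors(probs)
print(vectors.tolist())            # [[0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0]]
print(is_clean(vectors).tolist())  # [True, False]
print([LABELS[i] for i in np.flatnonzero(vectors[1])])  # ['toxic', 'insult']
```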

Setup

Leveraging Quarto and RStudio, I will set up an R and Python environment.

Import R libraries

Import R libraries. They are used both for rendering the document and for data analysis, because I prefer ggplot2 over matplotlib. I also use colorblind-safe palettes.

Code
library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)

reticulate::use_virtualenv("r-tf")

Import Python packages

Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp

from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score


from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score

Create a Config class to store all the useful parameters for the model and for the project.

Class Config

I created a class holding the basic configuration of the model, to improve readability.

Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911 # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571 # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model =  self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
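            # with threshold=None, Keras F1Score sets only the largest
            # probability of each row to 1 before scoring
            F1Score(name="f1", average="macro")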
            
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1", # "val_recall",
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1", # "val_recall",
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint

    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
        # instantiate KFold; only the validation indices are used, because a
        # decision threshold has no parameters to fit on the training folds
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
        threshold_scores = []

        for threshold in thresholds:
            cv_scores = []
            for _, val_index in kf.split(ytrue):
                ytrue_val = ytrue[val_index]
                yproba_val = yproba[val_index]

                # binarize the probabilities with the candidate threshold
                ypred_val = (yproba_val >= threshold).astype(int)
                score = metric(ytrue_val, ypred_val, average="macro")
                cv_scores.append(score)

            # mean score of this threshold across the folds
            threshold_scores.append((threshold, np.mean(cv_scores)))

        # find the threshold with the highest mean score (once, after the loop)
        best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
        return best_threshold, best_score

config = Config()

Data

The dataset is downloaded from the URL with tf.keras.utils.get_file. N.B. For reproducibility, I also saved a local copy of the dataset, since at times the link was not available.

Code
# df = pd.read_csv(config.path)
file = tf.keras.utils.get_file("Filter_Toxic_Comments_dataset.csv", config.url)
df = pd.read_csv(file)
Code
library(reticulate)

py$df %>%
  tibble() %>% 
  head(5)
Table 1: First 5 elements
# A tibble: 5 × 8
  comment_text            toxic severe_toxic obscene threat insult identity_hate
  <chr>                   <dbl>        <dbl>   <dbl>  <dbl>  <dbl>         <dbl>
1 "Explanation\nWhy the …     0            0       0      0      0             0
2 "D'aww! He matches thi…     0            0       0      0      0             0
3 "Hey man, I'm really n…     0            0       0      0      0             0
4 "\"\nMore\nI can't mak…     0            0       0      0      0             0
5 "You, sir, are my hero…     0            0       0      0      0             0
# ℹ 1 more variable: sum_injurious <dbl>

Let's create a clean variable for EDA purposes: I want to see visually how many observations are clean versus the other labels.

Code
df.loc[df.sum_injurious == 0, "clean"] = 1
df.loc[df.sum_injurious != 0, "clean"] = 0

EDA

First, a check on the dataset for possible missing values and imbalances.

Frequency

Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels

df_r_grouped <- df_r %>% 
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>% 
  group_by(label) %>%
  summarise(count = sum(value)) %>% 
  mutate(freq = round(count / sum(count), 4))

df_r_grouped
Table 2: Absolute and relative labels frequency
# A tibble: 7 × 3
  label          count   freq
  <chr>          <dbl>  <dbl>
1 clean         143346 0.803 
2 identity_hate   1405 0.0079
3 insult          7877 0.0441
4 obscene         8449 0.0473
5 severe_toxic    1595 0.0089
6 threat           478 0.0027
7 toxic          15294 0.0857

Barchart

Code
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu")
ggplotly(barchart)
Figure 1: Imbalance in the dataset with clean variable

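The plot shows how imbalanced the dataset is. This suggests computing class weights and passing them to the training.

Most of the comments are clean. Note that the freq column in Table 2 divides each count by the total number of label occurrences (a comment can carry several labels), so 0.8033 is the clean share of all label occurrences, not of the observations. Relative to the 159,571 comments, the 143,346 clean ones are about 89.8%, and only about 10.2% carry at least one toxic label.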
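To make the difference between the two denominators concrete, here is a small sketch on hypothetical toy data (not the real dataset): the freq column of Table 2 divides by the total number of label occurrences, while the fraction of clean comments divides by the number of comments.

```python
import numpy as np

# hypothetical multi-label matrix: 5 comments x 3 labels
y = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
    [1, 1, 0],   # one comment with two labels
    [1, 0, 1],
])
clean = (y.sum(axis=1) == 0).astype(int)
counts = np.append(y.sum(axis=0), clean.sum())  # per-label counts, clean last

# share of observations: divide by the number of comments
share_of_observations = counts / len(y)
# share of label occurrences: divide by the total count, as in Table 2
share_of_occurrences = counts / counts.sum()

print(share_of_observations[-1])           # 0.6
print(round(share_of_occurrences[-1], 4))  # 0.4286
```

Because multi-labelled comments inflate the denominator, the second number is always the smaller of the two for the clean class.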

Sequence length definition

To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.

One of its arguments is output_sequence_length: to choose it, it is useful to analyze the length of our texts. To simulate what the layer will do, we lowercase the comments and remove punctuation and newlines.

Summary

Code
library(reticulate)
df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  pull(text_length) %>% 
  summary() %>% 
  as.list() %>% 
  as_tibble()
Table 3: Summary of text length
# A tibble: 1 × 6
   Min. `1st Qu.` Median  Mean `3rd Qu.`  Max.
  <dbl>     <dbl>  <dbl> <dbl>     <dbl> <dbl>
1     4        91    196  378.       419  5000

Boxplot

Code
library(reticulate)
boxplot <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  # pull(text_length) %>% 
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
Figure 2: Text length boxplot

Histogram

Code
library(reticulate)
df_ <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )

Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)

histogram <- df_ %>% 
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Figure 3: Text length histogram with boxplot upper fence

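Considering all the above analysis, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, the outliers, which are a small part of our dataset, are truncated. Note that str_count() without a pattern counts characters, while output_sequence_length counts tokens, so 911 is a generous limit: a comment of 911 characters contains considerably fewer than 911 whitespace-separated tokens, and most sequences will be padded.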
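A minimal Python sketch of the same computation, on hypothetical comments, makes the unit explicit: the R pipeline above measures characters, whereas output_sequence_length is a number of tokens after splitting on whitespace. The helper names are illustrative, and string.punctuation only approximates R's [[:punct:]] class. numpy's default quantile method ("linear") matches R's default (type 7).

```python
import string
import numpy as np

def clean_text(text):
    """Lowercase, drop punctuation, then replace newlines (same order as the R code)."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    return text.replace("\n", " ")

def upper_fence(lengths):
    """Boxplot upper fence Q3 + 1.5 * IQR, truncated to an integer like as.integer()."""
    q1, q3 = np.quantile(lengths, [0.25, 0.75])
    return int(q3 + 1.5 * (q3 - q1))

# hypothetical comments
comments = ["Hello, world!", "This is\na longer comment.", "ok", "Thanks a lot, see you"]
cleaned = [clean_text(c) for c in comments]
char_lengths = [len(c) for c in cleaned]           # what str_count() measures
token_lengths = [len(c.split()) for c in cleaned]  # what output_sequence_length counts

print(char_lengths, token_lengths)                          # [11, 24, 2, 20] [2, 5, 1, 5]
print(upper_fence(char_lengths), upper_fence(token_lengths))  # 39 9
```

Computing the fence on token counts instead of character counts would give a much shorter, and cheaper, sequence length for the LSTM.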

Dataset

Now we can split the dataset into 3 parts: train, test and validation sets. Since sklearn has no function that splits into these 3 sets at once, we can do the following:
  • split into a train set and a temporary set, with a 0.3 split;
  • split the temporary set into 2 equally sized test and validation sets.

Code
x = df[config.features].values
y = df[config.labels].values

xtrain, xtemp, ytrain, ytemp = train_test_split(
  x,
  y,
  test_size=config.temp_split, # .3
  random_state=config.random_state
  )
xtest, xval, ytest, yval = train_test_split(
  xtemp,
  ytemp,
  test_size=config.test_split, # .5
  random_state=config.random_state
  )

xtrain shape: (111699,), ytrain shape: (111699, 6)
xtest shape: (23936,), ytest shape: (23936, 6)
xval shape: (23936,), yval shape: (23936, 6)
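As a sanity check on these shapes, on the cardinalities printed below and on the steps used later in the fit, the sizes can be reproduced with plain arithmetic. This sketch assumes scikit-learn's documented rounding for a float test_size (ceil for the test part, the remainder for the train part); it calls neither sklearn nor TensorFlow, and the helper name is illustrative.

```python
import math

def split_sizes(n_samples, test_size):
    """Sizes from train_test_split with a float test_size:
    the test part gets ceil(test_size * n), the train part gets the rest."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_total = 159571                                # config.total_samples
n_train, n_temp = split_sizes(n_total, 0.3)     # config.temp_split
n_test, n_val = split_sizes(n_temp, 0.5)        # config.test_split

batch_size = 32                                 # config.batch_size
# batches in a finite dataset: the last batch may be partial
cardinality = {name: math.ceil(n / batch_size)
               for name, n in [("train", n_train), ("val", n_val), ("test", n_test)]}
# steps used with .repeat(): integer division drops the partial batch
steps_per_epoch = n_train // batch_size
validation_steps = n_val // batch_size

print(n_train, n_test, n_val)             # 111699 23936 23936
print(cardinality)                        # {'train': 3491, 'val': 748, 'test': 748}
print(steps_per_epoch, validation_steps)  # 3490 748
```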

The datasets are created with the tf.data.Dataset API, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created with from_tensor_slices, which builds a tf.data.Dataset from a tuple (features, labels). The training set is also shuffled, with a buffer as large as the training set itself. .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the tf.data documentation for further information.

Code
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Code
print(
  f"train_ds cardinality: {train_ds.cardinality()}\n",
  f"val_ds cardinality: {val_ds.cardinality()}\n",
  f"test_ds cardinality: {test_ds.cardinality()}\n"
  )
train_ds cardinality: 3491
 val_ds cardinality: 748
 test_ds cardinality: 748

Check the first batch of the training dataset to make sure the pipeline is built correctly.

Code
train_ds.as_numpy_iterator().next()
(array([b"Although interesting, this is one of the stranger proposals to deal with this issue.  It's like saying BLP vios are okay, if only done in small font.  The images themselves are a significant part of the issue, and covering the event without using the images is incomplete.  We can make a conscious decision to sacrifice encyclopedia quality for the goal of free content, and in so doing diverge from most every other information source int he world, but if so that's what it is.",
       b'Calvin Harris (album) \n\nI have nominationed Calvin Harris (album) for deletion at Articles for deletion as it was the article creator  who removed the original PROD tag when removing my CSD tag, rather than me when adding the CSD tag. As such, it counts as PROD contest, which means the article no longer qualify for PROD. Thanks!',
       b", I'm actually supporting it now based on that one book, and the point  made at the beginning of this conversation about her heritage.",
       b'"\nI don\'t know if those TOC links are entirely useful. Wikipedia needs to come up with a way to link to items within a large table like that. Perhaps we could use the HTML code with links with the # like that?\n\n \xe2\x89\x88talk "',
       b'Piss off, she is an ignorant bitch.',
       b'"\n\nHey, nobody\'s wrong here, and nobody is defininitely a loser. I think that user Abhishek191288 reverted Markshen1985\'s edits because there are some opinions to his edits such as  logo is a brilliant red kapok delicately adoring a blue vertical tail fin. Also, much of ""History and Development is unsourced. Having said that, Markshen1985\'s work is remarkable. Also, there seems to be a touch of discrimination here, with ""Chinese"" being thrown around. I\'m going report this to .  (Talk) (Contributions)(Feed back needed @ Talk page) "',
       b'He means redirect.',
       b'. i would also like it to be noted that Amultic was the one who gave me my previous warnings after defending my self against a personnel attack by another user seehere ignoring the other users remarks, i question the merits of those warning which lead to this.',
       b"Oh wait, you're an American asshole... Let me explain how the law here work, if telenet uncovers who I am, they are violating several laws. I have free speech here.",
       b'gay \n\nwhy you so gay? you hiding something, fatty?',
       b'Article is on editing process... \n\nGreetings to all Editors!\n\nThis article is on-going process of revision and editing, please give us more time to work for it. The person in this article is notable, he do have a lot of contributions and achievement not only here at Asia, Europe and across America. I included all the possible references and links to support his notability. Any consideration to this is highly appreciated and with thanks. Best Regards. Turki Faisal Al Rasheed',
       b"Does anyone know if he performed in woodstock 1970.  They make Woodstock 70 a bold title.  And by boldness i do mean there was a woodstock 70.  i can't remember all of the performers and they are understated. i know some not enough.  i always thought the other players kept track better than me. None have come forward.  Could Jackson C.\nFrank be a link?",
       b'"\nMONKEY MONKEY MONKEY MONKEY MONKEY  \xe2\x80\x94Preceding unsigned comment added by 210.43.128.18   "',
       b'COOL ANYTHING ELSE D, are you gonna report me for being sarcastic as well?',
       b"~~Assuming that the first sentence in the Casablanca article is NOT supposed to be that it's the worst movie ever, someone has altered it to say so.  I don't have time right now to read the entire article to see if anything else has been childishly tampered with, but I thought I'd try to point it out.  If no one has fixed it up in a few days I'll try to wander back, though I know very little about the movie (I ended up here because I wanted info on the movie, not because I had it)  Cheers  (24 Nov 06)",
       b'Let me go further, this organization has been around for at least a decade. Given traditional high-level turnover rate of campaign staff there are probably (50 states x 10 years x ~5 state director turnover rate) ~2,500 former AFD state Exec Directors running around. Are we going to add articles on each one that gets mentioned in RS?',
       b'It was a fine tree.',
       b"By saying Jubilees is not canonical in any mainstream denomination, you are revealing your abject ignorance and bias POV, because while your personal religion may not consider it canonical, the Ethiopian Orthodox Christians and Ethiopian Jews do, and they are Abrahamic.  The Orthodox Christians make up a majority in Ethiopia.  The Constantinian Christians in the Roman Empire,  and the Pharisee Sanhedrin both tried to do away with Jubilees, but wikipedia is neutral and does not subscribe to any POV in disputes like that, nor will it assist you in attemting to marginalize Ethiopians' religious beliefs.",
       b'February 2015 (UTC)directly where relevant.   02:49, 18',
       b'" July 2006 (UTC)\n\nI haven\'t attacked anyone on any page.   23:05, 24"',
       b'But Abeg92 is a faggot.',
       b"Sir Gawain and the Green Knight\nThanks for bringing this to my attention. I've alerted the editor who has done the most to develop the article over the past several months, and I personally will take a more thorough look when I have time. t/c",
       b"Homage to Mussorgsky: Epithalamium \n\nAm I the only Wikipedian who was at the premi\xc3\xa8re? If not, please share your thoughts. As it happens, my father is the pianist's agent, so that's hwy I was there: I wasn't sure what to think of the piece - I preferred the Lyapunov Transcendental Studies that Malcolm Binns also played that night.Vox  8'",
       b'VfD\nOn May 4, 2005, this article was nominated for deletion.  The result was merge/redirect.  See Wikipedia:Votes for deletion/Boc for a record of the discussion.  (spill yours?)',
       b'Had a look. Started making changes. Noticed WP:Copyvio, logged my concern via helpme and page got deleted. See John_W._R._Taylor',
       b'Thank you for your patience. I hope all of this will end well. 92.36.173.254',
       b'Hi I have found some more links for you http://www.wwfindia.org/about_wwf/what_we_do/freshwater_wetlands/our_work/ramsar_sites/harike_wetlands_.cfm and  http://www.birding.in/birdingsites/harike_lake.htm and also http://www.gisdevelopment.net/aars/acrs/1997/ts7/ts7008.asp and finally a brilliant article on the History http://www.punjabheritage.org/architectural-heritage/local-enthusiasm-but-official-neglect-for-anglo-sikh-war-monuments2704.html',
       b'"\n\n Yes, I noticed, the text, mostly in the section ""Lennox Crisis"" is a summary of this article, Macauley, Sarah. \'The Lennox Crisis\', 1558-1563.\', Northern History 41.2 (2004). I can\'t tell if this text a good summary of that article, or if that article is useful from this summary.  "',
       b'YOUR A JERK!!!!!!! \n\nYou are a jerk yo lazy butt!',
       b"thanks,  it appears that you didn't read any of the links above. here i will provide the key one again: Wikipedia:Conflict_of_interest#Writing_about_yourself_and_your_work. i think your response would have been different, had you read that.  would you please read that, and reply again?  thanks.",
       b'"\nResponse much appreciated, thanks. Well, I\'m trying to retain as much of your text as possible. There\'s a lot of good stuff in there which simply needs sourcing and updating in with the other material. That\'s why I abandoned the sandbox idea and thought it best to just run with what we have. The article is going to get massive in the next few weeks though to make it comprehensive as possible, but it\'ll be split and condensed later. I really do need your help if possible though on tracing the source of the unsourced material. I\'ll try to do what I can but I\'ll approach you if I can\'t find a source for something.\xe2\x99\xa6  "',
       b"To: User:Protonk, treating people like they are children will not get you any respect from free thinking individuals and I'd like you to know you're up for review."],
      dtype=object), array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 1],
       [1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 1],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]))

We also check the shapes: we expect features of shape (batch,) and targets of shape (batch, number of labels).

Code
print(
  f"text train shape: {train_ds.as_numpy_iterator().next()[0].shape}\n",
  f" text train type: {train_ds.as_numpy_iterator().next()[0].dtype}\n",
  f"label train shape: {train_ds.as_numpy_iterator().next()[1].shape}\n",
  f"label train type: {train_ds.as_numpy_iterator().next()[1].dtype}\n"
  )
text train shape: (32,)
  text train type: object
 label train shape: (32, 6)
 label train type: int64

Preprocessing

Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:
  1. Standardize each example (usually lowercasing + punctuation stripping).
  2. Split each example into substrings (usually words).
  3. Recombine substrings into tokens (usually ngrams).
  4. Index tokens (associate a unique int value with each token).
  5. Transform each example using this index, either into a vector of ints or a dense float vector.

For more details, see the Keras TextVectorization documentation.

Code
text_vectorization = TextVectorization(
  max_tokens=config.max_tokens,
  standardize="lower_and_strip_punctuation",
  split="whitespace",
  output_mode="int",
  output_sequence_length=config.output_sequence_length,
  pad_to_max_tokens=True
  )

# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)

This layer is set to:
  • max_tokens: 20000, a common value for text classification. It is the maximum size of the vocabulary for this layer.
  • output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
  • output_mode: "int", which outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
  • standardize: "lower_and_strip_punctuation".
  • split: on whitespace.
  • pad_to_max_tokens: True, although this option only has an effect in the "multi_hot", "count" and "tf_idf" modes.
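To make the steps above and the reserved indices concrete, here is a deliberately simplified pure-Python imitation of what the layer does in "int" mode: index 0 for padding, index 1 for out-of-vocabulary tokens, then tokens by decreasing frequency. It is not the Keras implementation: the helper names (standardize, adapt, vectorize) and the corpus are hypothetical, and details such as tie-breaking between equally frequent tokens may differ.

```python
import re
import string
from collections import Counter

PAD, OOV = "", "[UNK]"  # index 0 = padding/mask, index 1 = out-of-vocabulary

def standardize(text):
    """Rough equivalent of standardize="lower_and_strip_punctuation"."""
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text.lower())

def adapt(corpus, max_tokens):
    """Vocabulary: the 2 reserved entries, then the max_tokens - 2 most frequent tokens."""
    counts = Counter(tok for doc in corpus for tok in standardize(doc).split())
    most_common = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return {tok: i for i, tok in enumerate([PAD, OOV] + most_common)}

def vectorize(text, index, output_sequence_length):
    """Map tokens to ints (unknown -> 1), then truncate or pad with 0."""
    ids = [index.get(tok, 1) for tok in standardize(text).split()]
    ids = ids[:output_sequence_length]
    return ids + [0] * (output_sequence_length - len(ids))

corpus = ["You are great!", "you are", "you"]  # hypothetical training texts
index = adapt(corpus, max_tokens=10)
print(index)  # {'': 0, '[UNK]': 1, 'you': 2, 'are': 3, 'great': 4}
print(vectorize("You are a great friend", index, output_sequence_length=6))
# [2, 3, 1, 4, 1, 0]
```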

To keep the original comments as text while also having a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, we map the layer over the features of each dataset.

Code
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

Model

Definition

Define the model using the Functional API.

Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()
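Since the output of lstm_model.summary() is not reproduced here, a worked count of the trainable parameters can help read it. This sketch assumes the standard Keras parameterization of LSTM (4 gates, each with an input kernel, a recurrent kernel and one bias vector) and a scale plus an offset per feature for LayerNormalization; pooling and dropout add no parameters. The helper names are illustrative.

```python
def lstm_params(input_dim, units):
    """Keras LSTM: 4 gates x (input kernel + recurrent kernel + bias)."""
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units

max_tokens, embedding_dim, n_labels = 20000, 128, 6  # values from Config
layers = {
    "embedding": max_tokens * embedding_dim,
    "bilstm_1": 2 * lstm_params(embedding_dim, 256),  # forward + backward
    "bilstm_2": 2 * lstm_params(2 * 256, 128),        # input: 512 concatenated features
    "dense_64": dense_params(2 * 128, 64),            # after global average pooling
    "layer_norm": 2 * 64,                             # gamma and beta
    "outputs": dense_params(64, n_labels),
}
total = sum(layers.values())
print(layers)
print(total)  # 4021830
```

Most of the parameters sit in the embedding table (2,560,000), so max_tokens and embedding_dim drive the model size more than the LSTM widths.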

Callbacks

Finally, the model is trained with 2 callbacks:
  • EarlyStopping, to avoid wasting Kaggle GPU time;
  • ModelCheckpoint, to save the best model seen during training (save_best_only=True).

Code
# callbacks
my_es = config.get_early_stopping()
my_mc = config.get_model_checkpoint(filepath="/checkpoint.keras")
callbacks = [my_es, my_mc]
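To see how min_delta, patience and start_from_epoch interact, here is a simplified pure-Python version of the EarlyStopping rule for mode="max", applied to a hypothetical, steadily improving validation curve. It is an approximation written for illustration: the function name is hypothetical, and options such as baseline and restore_best_weights are ignored.

```python
def stopping_epoch(scores, min_delta, patience, start_from_epoch):
    """Simplified EarlyStopping rule for mode="max".

    Epochs before start_from_epoch are ignored. An epoch counts as an
    improvement only if its score exceeds the best so far by more than
    min_delta; training stops once `patience` epochs pass without one.
    Returns the 0-based epoch at which training stops, or None.
    """
    best = float("-inf")
    wait = 0
    for epoch, score in enumerate(scores):
        if epoch < start_from_epoch:
            continue
        if score - min_delta > best:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# hypothetical val_f1 curve gaining 0.01 per epoch over 100 epochs
scores = [0.01 * i for i in range(100)]
print(stopping_epoch(scores, min_delta=0.2, patience=10, start_from_epoch=3))  # 13
print(stopping_epoch(scores, min_delta=0.0, patience=10, start_from_epoch=3))  # None
```

With min_delta=0.2, as in get_early_stopping, gains of 0.01 per epoch never count as improvements, so this sketch stops ten epochs after monitoring starts; with min_delta=0.0 the same curve trains for all 100 epochs.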

Final preparation before fit

Since the dataset is imbalanced, we compute class weights to try to improve performance. They are passed to the model during training.

Code
lab = pd.DataFrame(columns=config.labels, data=ytrain)
r = lab.sum() / len(ytrain)  # share of training comments with each label (a frequency, not an inverse frequency)
class_weight = dict(zip(range(len(config.labels)), r))
df_class_weight = pd.DataFrame.from_dict(
  data=class_weight,
  orient='index',
  columns=['class_weight']
  )
df_class_weight.index = config.labels
Code
library(reticulate)
py$df_class_weight
Table 4: Class weight
              class_weight
toxic          0.095900590
severe_toxic   0.009928468
obscene        0.052757858
threat         0.003061800
insult         0.049132042
identity_hate  0.008710911
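The values in Table 4 are the positive frequencies of each label, so the rarest labels get the smallest numbers. A common alternative for multi-label imbalance, shown here only as a hedged sketch and not what was used to train this model, is a per-label weight for positive examples equal to negatives / positives, applied inside a weighted binary cross-entropy. The function names and toy targets are hypothetical.

```python
import numpy as np

def positive_weights(y):
    """Per-label weight for positive examples: negatives / positives."""
    pos = y.sum(axis=0)
    neg = len(y) - pos
    return neg / np.maximum(pos, 1)  # guard against labels with no positives

def weighted_bce(y_true, y_prob, pos_weight, eps=1e-7):
    """Binary cross-entropy where each label's positive term is scaled by pos_weight."""
    y_prob = np.clip(y_prob, eps, 1 - eps)
    loss = -(pos_weight * y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return loss.mean()

# hypothetical targets: 8 comments, 2 labels, the second much rarer
y = np.array([[1, 0], [0, 0], [1, 0], [0, 1], [0, 0], [0, 0], [0, 0], [0, 0]])
w = positive_weights(y)
print(w)  # [3. 7.]
```

Here a missed positive of the rare label costs 7 times as much as in the unweighted loss, which is the opposite of what passing raw frequencies does.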

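It is also useful to define the steps per epoch for the train and validation datasets. This is required because both datasets are repeated during the fit, so Keras needs to know how many batches make up one epoch.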

Code
steps_per_epoch = config.train_samples // config.batch_size
validation_steps = config.val_samples // config.batch_size

Fit

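The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:

  • .repeat() makes the training dataset cycle indefinitely, so it never runs out of batches across epochs.
  • epochs is set to 100.
  • validation_data is repeated in the same way.
  • callbacks are the ones defined before.
  • class_weight passes the values computed above, because our dataset is imbalanced. These are the label frequencies rather than their inverses, so rarer labels get smaller weights; moreover, for targets with more than one column Keras turns class_weight into per-sample weights using the argmax of each target row, so with multi-label targets the weighting is only approximate.
  • steps_per_epoch and validation_steps are needed because of repeat(): without them, an epoch on an infinite dataset would never end.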
Code
history = lstm_model.fit(
  processed_train_ds.repeat(),
  epochs=config.epochs,
  validation_data=processed_val_ds.repeat(),
  callbacks=callbacks,
  class_weight=class_weight,
  steps_per_epoch=steps_per_epoch,
  validation_steps=validation_steps
  )

Now we can load the model and the training history produced on Kaggle.

Code
model = load_model(filepath=config.model)
history = pd.read_excel(config.history)

Evaluate

Code
validation = model.evaluate(
  processed_val_ds.repeat(),
  steps=validation_steps, # 748
  verbose=0
  )
Code
tibble(
  metric = c("loss", "precision", "recall", "auc", "f1_score"),
  value = py$validation
  )
Table 5: Model validation metric
# A tibble: 5 × 2
  metric     value
  <chr>      <dbl>
1 loss      0.0542
2 precision 0.789 
3 recall    0.671 
4 auc       0.957 
5 f1_score  0.0293

Predict

For prediction there is no need to repeat the dataset: the model is already trained and only has to consume each test batch once. Since test_ds is not shuffled, the predictions stay in the same order as ytest.

Code
predictions = model.predict(processed_test_ds, verbose=0)

Confusion Matrix

A useful way to assess the performance of a multi-label classifier is the confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, which computes one 2×2 matrix per label to handle the fact that a single observation can have multiple labels.
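To show what that function returns, here is a minimal NumPy sketch that computes the same per-label counts by hand on hypothetical predictions. The helper name is illustrative; the [[tn, fp], [fn, tp]] layout follows scikit-learn's documentation for multilabel_confusion_matrix.

```python
import numpy as np

def per_label_confusion(y_true, y_pred):
    """One 2x2 matrix per label, laid out as [[tn, fp], [fn, tp]]."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = (y_true & y_pred).sum(axis=0)
    fp = (~y_true & y_pred).sum(axis=0)
    fn = (y_true & ~y_pred).sum(axis=0)
    tn = (~y_true & ~y_pred).sum(axis=0)
    # stack into shape (n_labels, 2, 2)
    return np.stack([np.stack([tn, fp], axis=-1), np.stack([fn, tp], axis=-1)], axis=1)

# hypothetical targets and predictions: 4 comments x 2 labels
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_pred = [[1, 0], [1, 1], [0, 1], [0, 0]]
print(per_label_confusion(y_true, y_pred).tolist())
# [[[1, 1], [1, 1]], [[2, 0], [0, 2]]]
```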

Grid Search Cross Validation for best threshold

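Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model: it systematically searches through a set of hyperparameter values to find the combination that leads to the best performance. Here the value being tuned is the decision threshold, and I combine the search with KFold cross-validation, a resampling technique that splits the data into k consecutive folds: each fold is used once as validation while the remaining k - 1 folds form the training set. Since a threshold has nothing to fit, find_optimal_threshold_cv only uses the validation folds: each candidate threshold is scored on every fold of the test predictions and the scores are averaged. See the documentation for more information.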

The model is optimized for recall. The decision was made because the cost of missing a true positive is greater than that of a false positive: in this case, missing an injurious comment is worse than flagging a clean one as injurious.

That said, I also test metrics other than recall_score, to have more options when choosing the best threshold.

f1_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_f1, best_score_f1 = config.find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)

print(f"Optimal threshold: {optimal_threshold_f1}")
Optimal threshold: 0.15000000000000002
Code
print(f"Best score: {best_score_f1}")
Best score: 0.4788653077945807
Code

# Use the optimal threshold to make predictions
final_predictions_f1 = (y_pred_proba >= optimal_threshold_f1).astype(int)

Optimal threshold f1 score: 0.15. Best score: 0.4788653.

recall_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_recall, best_score_recall = config.find_optimal_threshold_cv(ytrue, y_pred_proba, recall_score)

# Use the optimal threshold to make predictions
final_predictions_recall = (y_pred_proba >= optimal_threshold_recall).astype(int)

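Optimal threshold recall: 0.05. Best score: 0.8095814. As expected, the search picks the lowest threshold in the grid: lowering the threshold can only turn predictions from 0 to 1, so recall never decreases.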

roc_auc_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_roc, best_score_roc = config.find_optimal_threshold_cv(ytrue, y_pred_proba, roc_auc_score)

print(f"Optimal threshold: {optimal_threshold_roc}")
Optimal threshold: 0.05
Code
print(f"Best score: {best_score_roc}")
Best score: 0.8809499649742268
Code

# Use the optimal threshold to make predictions
final_predictions_roc = (y_pred_proba >= optimal_threshold_roc).astype(int)

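Optimal threshold roc: 0.05. Best score: 0.88095. Note that roc_auc_score receives the binarized predictions here, not the probabilities, so for each label it reduces to the mean of the true positive and true negative rates rather than the usual threshold-free AUC.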
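A small sketch illustrates why the roc_auc_score values above depend on the threshold: an AUC computed from 0/1 predictions equals the mean of the true positive and true negative rates (balanced accuracy). The sketch computes AUC with the pairwise (Mann-Whitney) formulation, which matches the area under the ROC curve, on hypothetical data; the helper names are illustrative.

```python
import numpy as np

def pairwise_auc(y_true, scores):
    """AUC as the probability that a random positive outscores a random
    negative, counting ties as 1/2."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    diff = scores[y_true][:, None] - scores[~y_true][None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def balanced_accuracy(y_true, y_pred):
    """Mean of the true positive rate and the true negative rate."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tpr = (y_true & y_pred).sum() / y_true.sum()
    tnr = (~y_true & ~y_pred).sum() / (~y_true).sum()
    return (tpr + tnr) / 2

# hypothetical label with 3 positives and 5 negatives
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 0])
y_prob = np.array([0.9, 0.3, 0.04, 0.6, 0.2, 0.1, 0.03, 0.01])
y_pred = (y_prob >= 0.05).astype(int)

print(pairwise_auc(y_true, y_prob))       # AUC on probabilities: 11/15
print(pairwise_auc(y_true, y_pred))       # AUC on 0/1 predictions: 8/15
print(balanced_accuracy(y_true, y_pred))  # same as the line above: 8/15
```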

Confusion Matrix Plot

The confusion matrix is computed with the multilabel_confusion_matrix function in scikit-learn, which returns one matrix per label. To build it, we need to convert the predicted probability of each label into a proper prediction; to do so, we use the optimal threshold found for the F1 score, which is 0.15. Since this is a multi-label task, instead of one big matrix with all the labels as columns and indices, we plot a confusion matrix for each label with a simple for loop, which extracts at each iteration the confusion matrix and the associated label.

Code
# convert probability predictions to predictions
ypred = predictions >= optimal_threshold_f1  # 0.15
ypred = ypred.astype(int)

# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
Figure 4: Multi Label Confusion matrix

Classification Report

Code
cr = classification_report(
  ytrue,
  ypred,
  target_names=config.labels,
  digits=4,
  output_dict=True
  )
df_cr = pd.DataFrame.from_dict(cr).reset_index()
Code
library(reticulate)
df_cr <- py$df_cr %>% dplyr::rename(names = index)
cols <- df_cr %>% colnames()
df_cr %>% 
  pivot_longer(
    cols = -names,
    names_to = "metrics",
    values_to = "values"
  ) %>% 
  pivot_wider(
    names_from = names,
    values_from = values
  )
Table 6: Classification report
# A tibble: 10 × 5
   metrics       precision recall `f1-score` support
   <chr>             <dbl>  <dbl>      <dbl>   <dbl>
 1 toxic            0.677  0.840      0.749     2262
 2 severe_toxic     0.328  0.812      0.468      240
 3 obscene          0.655  0.888      0.754     1263
 4 threat           0      0          0           69
 5 insult           0.574  0.849      0.685     1170
 6 identity_hate    0.154  0.372      0.218      207
 7 micro avg        0.584  0.822      0.683     5211
 8 macro avg        0.398  0.627      0.479     5211
 9 weighted avg     0.603  0.822      0.692     5211
10 samples avg      0.0603 0.0764     0.0641    5211

Conclusions

The BiLSTM model, optimized for a high recall, performs well enough to make predictions for most labels, with the clear exception of threat (precision and recall both 0 in Table 6) and, to a lesser extent, identity_hate (F1 about 0.22). See Table 2 and Figure 1: the threat label appears in only 478 comments, about 0.3% of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment as injurious.

Possible improvements include collecting more observations, especially for the threat label. In general, clean comments dominate the dataset; this could be addressed by undersampling them, which I deliberately avoided in order to test how the BiLSTM performs on an imbalanced dataset using class weights.
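As a hedged sketch of the undersampling option, assuming the training targets are held in a NumPy array shaped like ytrain, random undersampling of the clean comments could look like this. The helper name and the kept fraction are hypothetical. It should be applied to the training split only, so that the validation and test sets keep the real class distribution.

```python
import numpy as np

def undersample_clean(y, clean_fraction, seed=42):
    """Return sorted row indices that keep every labelled comment and a
    random `clean_fraction` of the clean ones (rows whose labels are all 0)."""
    rng = np.random.default_rng(seed)
    clean_idx = np.flatnonzero(y.sum(axis=1) == 0)
    toxic_idx = np.flatnonzero(y.sum(axis=1) > 0)
    n_keep = int(round(clean_fraction * len(clean_idx)))
    kept_clean = rng.choice(clean_idx, size=n_keep, replace=False)
    return np.sort(np.concatenate([toxic_idx, kept_clean]))

# hypothetical targets: 8 clean comments and 2 toxic ones
y = np.zeros((10, 6), dtype=int)
y[[3, 7], 0] = 1
idx = undersample_clean(y, clean_fraction=0.25)
print(len(idx))  # 4: both toxic rows plus 2 of the 8 clean rows
```

The selected indices would then be used to subset both the texts and the targets before building the training tf.data.Dataset.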